package com.lzh.crypt.rsa;

import com.lzh.crypt.enums.RsaAlgorithmEnum;
import com.lzh.crypt.exception.ErrorCode;
import com.lzh.crypt.exception.RSAException;
import org.apache.commons.codec.binary.Base64;

import javax.crypto.Cipher;
import java.security.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * An RSA key pair bundled with the algorithm names used by its encrypt/decrypt/sign/verify helpers.
 *
 * <p>{@code algorithm} is handed to {@link Cipher#getInstance(String)} by the encryption and
 * decryption helpers. The part of it before the first {@code '/'} (for example {@code "RSA"} for
 * {@code "RSA/ECB/OAEPWithSHA-256AndMGF1Padding"}) is the key algorithm used when generating or
 * decoding keys. {@code signAlgorithm} is handed to {@link Signature#getInstance(String)}.
 *
 * <p>Failures inside the instance helpers are rethrown as {@link RSAException} carrying the original
 * exception as the cause.
 *
 * @author zhehen.lu
 * @date 2025/8/10 19:52
 */
public class RSAKeyPair {

    /** Default cipher transformation (and key algorithm). */
    private static final String DEFAULT_ALGORITHM = "RSA";

    /**
     * Default signature algorithm.
     * NOTE(review): MD5-based signatures are considered weak; this default is kept only so that
     * previously produced signatures keep verifying. Prefer passing e.g. "SHA256withRSA" explicitly.
     */
    private static final String DEFAULT_SIGN_ALGORITHM = "MD5withRSA";

    /** Modulus size, in bits, used when a key pair is generated without an explicit size. */
    private static final int DEFAULT_KEY_SIZE = 2048;

    private RSAPrivateKey rsaPrivateKey;

    private RSAPublicKey rsaPublicKey;

    /** Cipher transformation used by the encrypt/decrypt helpers. */
    private final String algorithm;

    /** Signature algorithm used by {@link #sign(byte[])} and {@link #verifyPublicKey(byte[], byte[])}. */
    private final String signAlgorithm;

    /**
     * Generates a fresh 2048-bit key pair using the default cipher ("RSA") and signature
     * ("MD5withRSA") algorithms.
     *
     * @return a new key pair
     * @throws RSAException if the key pair generator or the strong random source is unavailable
     */
    public static RSAKeyPair getInstance() {
        KeyPair keyPair = generateKeyPair(RsaAlgorithmEnum.RSA.getValue(), DEFAULT_KEY_SIZE);
        return new RSAKeyPair(keyPair);
    }

    /**
     * Generates a fresh 2048-bit key pair bound to the given algorithms.
     *
     * @param algorithm     cipher transformation, e.g. {@code "RSA"} or {@code "RSA/ECB/PKCS1Padding"};
     *                      its part before the first {@code '/'} selects the key pair generator
     * @param signAlgorithm signature algorithm, e.g. {@code "SHA256withRSA"}
     * @return a new key pair that uses {@code algorithm} and {@code signAlgorithm}
     * @throws RSAException if the key pair generator or the strong random source is unavailable
     */
    public static RSAKeyPair getInstance(String algorithm, String signAlgorithm) {
        return getInstance(algorithm, signAlgorithm, DEFAULT_KEY_SIZE);
    }

    /**
     * Generates a fresh key pair of the given size bound to the given algorithms.
     *
     * @param algorithm     cipher transformation; its part before the first {@code '/'} selects the
     *                      key pair generator
     * @param signAlgorithm signature algorithm
     * @param keySize       modulus size in bits
     * @return a new key pair that uses {@code algorithm} and {@code signAlgorithm}
     * @throws RSAException if the key pair generator or the strong random source is unavailable
     * @throws java.security.InvalidParameterException if the generator rejects {@code keySize}
     */
    public static RSAKeyPair getInstance(String algorithm, String signAlgorithm, int keySize) {
        KeyPair keyPair = generateKeyPair(keyAlgorithmOf(algorithm), keySize);
        // Previously the algorithms were dropped here and the defaults were used instead.
        return new RSAKeyPair(keyPair, algorithm, signAlgorithm);
    }

    /**
     * Wraps an existing key pair.
     *
     * @param keyPair       pair whose keys must be {@link RSAPrivateKey}/{@link RSAPublicKey}
     * @param algorithm     cipher transformation for the encrypt/decrypt helpers
     * @param signAlgorithm signature algorithm for sign/verify
     * @throws ClassCastException if the pair does not hold RSA keys
     */
    public RSAKeyPair(KeyPair keyPair, String algorithm, String signAlgorithm) {
        this.algorithm = algorithm;
        this.signAlgorithm = signAlgorithm;
        this.rsaPrivateKey = (RSAPrivateKey) keyPair.getPrivate();
        this.rsaPublicKey = (RSAPublicKey) keyPair.getPublic();
    }

    /**
     * Wraps an existing key pair using the default cipher ("RSA") and signature ("MD5withRSA")
     * algorithms.
     *
     * @param keyPair pair whose keys must be {@link RSAPrivateKey}/{@link RSAPublicKey}
     * @throws ClassCastException if the pair does not hold RSA keys
     */
    public RSAKeyPair(KeyPair keyPair) {
        this(keyPair, DEFAULT_ALGORITHM, DEFAULT_SIGN_ALGORITHM);
    }

    /**
     * Decodes a key pair from Base64 strings using the default cipher ("RSA") and signature
     * ("MD5withRSA") algorithms.
     *
     * @param privateKeyBase64 Base64 of a PKCS#8-encoded private key
     * @param publicKeyBase64  Base64 of an X.509 (SubjectPublicKeyInfo)-encoded public key
     * @throws Exception if the key factory is unavailable or a key cannot be decoded
     */
    public RSAKeyPair(String privateKeyBase64, String publicKeyBase64) throws Exception {
        this(privateKeyBase64, publicKeyBase64, DEFAULT_ALGORITHM, DEFAULT_SIGN_ALGORITHM);
    }

    /**
     * Decodes a key pair from Base64 strings.
     *
     * @param privateKeyBase64 Base64 of a PKCS#8-encoded private key
     * @param publicKeyBase64  Base64 of an X.509 (SubjectPublicKeyInfo)-encoded public key
     * @param algorithm        cipher transformation; its part before the first {@code '/'} selects
     *                         the key factory
     * @param signAlgorithm    signature algorithm
     * @throws Exception if the key factory is unavailable or a key cannot be decoded
     */
    public RSAKeyPair(String privateKeyBase64, String publicKeyBase64, String algorithm, String signAlgorithm)
            throws Exception {
        this.algorithm = algorithm;
        this.signAlgorithm = signAlgorithm;
        PKCS8EncodedKeySpec pkcs8EncodedKeySpec = new PKCS8EncodedKeySpec(Base64.decodeBase64(privateKeyBase64));
        X509EncodedKeySpec x509EncodedKeySpec = new X509EncodedKeySpec(Base64.decodeBase64(publicKeyBase64));
        // KeyFactory only knows key algorithms ("RSA"), not full cipher transformations.
        KeyFactory keyFactory = KeyFactory.getInstance(keyAlgorithmOf(algorithm));
        this.rsaPublicKey = (RSAPublicKey) keyFactory.generatePublic(x509EncodedKeySpec);
        this.rsaPrivateKey = (RSAPrivateKey) keyFactory.generatePrivate(pkcs8EncodedKeySpec);
    }

    /** @return the private key */
    public RSAPrivateKey getRsaPrivateKey() {
        return rsaPrivateKey;
    }

    /** @param rsaPrivateKey replacement private key */
    public void setRsaPrivateKey(RSAPrivateKey rsaPrivateKey) {
        this.rsaPrivateKey = rsaPrivateKey;
    }

    /** @return the public key */
    public RSAPublicKey getRsaPublicKey() {
        return rsaPublicKey;
    }

    /** @param rsaPublicKey replacement public key */
    public void setRsaPublicKey(RSAPublicKey rsaPublicKey) {
        this.rsaPublicKey = rsaPublicKey;
    }

    /** @return the public key's encoded form ({@link java.security.Key#getEncoded()}) as Base64 */
    public String getPublickeyBase64() {
        return Base64.encodeBase64String(rsaPublicKey.getEncoded());
    }

    /** @return the private key's encoded form ({@link java.security.Key#getEncoded()}) as Base64 */
    public String getPrivatekeyBase64() {
        return Base64.encodeBase64String(rsaPrivateKey.getEncoded());
    }

    /**
     * Verifies a signature over {@code content} with the public key.
     *
     * @param content signed data
     * @param sign    signature bytes
     * @return {@code true} if the signature is valid
     * @throws RSAException on any failure (unknown algorithm, bad key, malformed signature, ...)
     */
    public boolean verifyPublicKey(byte[] content, byte[] sign) {
        try {
            Signature signature = Signature.getInstance(signAlgorithm);
            signature.initVerify(rsaPublicKey);
            signature.update(content);
            return signature.verify(sign);
        } catch (Exception e) {
            throw new RSAException(ErrorCode.RSA_VALIDATION_ERROR.getErrorMsg(), e);
        }
    }

    /**
     * Signs {@code content} with the private key.
     *
     * @param content data to sign
     * @return the signature bytes
     * @throws RSAException on any failure
     */
    public byte[] sign(byte[] content) {
        try {
            Signature signature = Signature.getInstance(signAlgorithm);
            signature.initSign(rsaPrivateKey);
            signature.update(content);
            return signature.sign();
        } catch (Exception e) {
            throw new RSAException(ErrorCode.RSA_SIGNATURE_ERROR.getErrorMsg(), e);
        }
    }

    /**
     * Encrypts with the public key in a single {@code doFinal} call (no chunking), so the input
     * must fit in one RSA block for the key size and padding in use.
     *
     * @param bs plaintext
     * @return ciphertext
     * @throws RSAException on any failure
     */
    public byte[] encrypt(byte[] bs) {
        try {
            return initCipher(Cipher.ENCRYPT_MODE, rsaPublicKey).doFinal(bs);
        } catch (Exception e) {
            throw new RSAException(ErrorCode.RSA_ENCRYPTION_EXCEPTION.getErrorMsg(), e);
        }
    }

    /**
     * Encrypts with the private key in a single {@code doFinal} call (no chunking).
     *
     * @param bs plaintext
     * @return ciphertext
     * @throws RSAException on any failure
     */
    public byte[] encryptByPrivateKey(byte[] bs) {
        try {
            return initCipher(Cipher.ENCRYPT_MODE, rsaPrivateKey).doFinal(bs);
        } catch (Exception e) {
            throw new RSAException(ErrorCode.RSA_ENCRYPTION_EXCEPTION.getErrorMsg(), e);
        }
    }

    /**
     * Decrypts with the private key (counterpart of {@link #encrypt(byte[])}).
     *
     * @param bs ciphertext
     * @return plaintext
     * @throws RSAException on any failure
     */
    public byte[] decrypt(byte[] bs) {
        try {
            return initCipher(Cipher.DECRYPT_MODE, rsaPrivateKey).doFinal(bs);
        } catch (Exception e) {
            throw new RSAException(ErrorCode.RSA_DECRYPTION_EXCEPTION.getErrorMsg(), e);
        }
    }

    /**
     * Decrypts with the public key (counterpart of {@link #encryptByPrivateKey(byte[])}).
     *
     * @param bs ciphertext
     * @return plaintext
     * @throws RSAException on any failure
     */
    public byte[] decryptByPublicKey(byte[] bs) {
        try {
            return initCipher(Cipher.DECRYPT_MODE, rsaPublicKey).doFinal(bs);
        } catch (Exception e) {
            throw new RSAException(ErrorCode.RSA_DECRYPTION_EXCEPTION.getErrorMsg(), e);
        }
    }

    /**
     * Same as {@link #decrypt(byte[])}; misspelled name kept for existing callers.
     *
     * @param bs ciphertext
     * @return plaintext
     */
    public byte[] decypt(byte[] bs) {
        return decrypt(bs);
    }

    /**
     * Same as {@link #decryptByPublicKey(byte[])}; misspelled name kept for existing callers.
     *
     * @param bs ciphertext
     * @return plaintext
     */
    public byte[] decyptByPublicKey(byte[] bs) {
        return decryptByPublicKey(bs);
    }

    /** Creates a cipher for {@link #algorithm} initialised in {@code mode} with {@code key}. */
    private Cipher initCipher(int mode, Key key) throws GeneralSecurityException {
        Cipher cipher = Cipher.getInstance(algorithm);
        cipher.init(mode, key);
        return cipher;
    }

    /** Generates a key pair of {@code keySize} bits, wrapping a missing algorithm in RSAException. */
    private static KeyPair generateKeyPair(String keyAlgorithm, int keySize) {
        try {
            KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(keyAlgorithm);
            // NOTE(review): getInstanceStrong() may block on some platforms while gathering entropy;
            // consider new SecureRandom() if key generation stalls.
            keyPairGenerator.initialize(keySize, SecureRandom.getInstanceStrong());
            return keyPairGenerator.generateKeyPair();
        } catch (NoSuchAlgorithmException e) {
            throw new RSAException(ErrorCode.RSA_INSTANTIATION_ERROR.getErrorMsg(), e);
        }
    }

    /** Returns the key algorithm of a cipher transformation: the text before the first '/'. */
    private static String keyAlgorithmOf(String algorithm) {
        int slash = algorithm.indexOf('/');
        return slash < 0 ? algorithm : algorithm.substring(0, slash);
    }
}
